{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yV8Q-dayhexF"
      },
      "source": [
        "# SEEDS Demo\n",
        "\n",
        "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-research/google-research/blob/master/seeds/SEEDS_Demo.ipynb)\n",
        "\n",
        "**Enable TPU to run the notebook**\n",
        "\n",
        "- This demo colab runs on the free hosted TPU colab kernel.\n",
        "- Navigate to the menu item `Runtime` → `Change runtime type`.\n",
        "- Select `TPU` and click `Save`.\n",
        "- Click on `Connect` on the top right. You are ready once you see  `✓TPU` next to the RAM \u0026 Disk display."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lulVWjnBFPC8"
      },
      "source": [
        "## Copyright notice"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sTyXfhe8FC5e"
      },
      "source": [
        "Copyright 2023 Google LLC\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    http://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License.\n",
        "\n",
        "\n",
        "---\n",
        "\n",
        "\n",
        "This is not an official Google product.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OCrcQYenXCQk"
      },
      "source": [
        "# Preparation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pxAdPiFJE2bE"
      },
      "outputs": [],
      "source": [
        "# @title Install external packages\n",
        "!pip install ecmwflibs cfgrib eccodes\n",
        "!pip install cartopy matplotlib numpy pandas scipy seaborn tqdm\n",
        "!pip install xarray[complete]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iK5sR3OFZi9F"
      },
      "outputs": [],
      "source": [
        "# @title Install SEEDS package\n",
        "%%shell\n",
        "git clone -n --depth=1 --filter=tree:0 https://github.com/google-research/google-research\n",
        "cd google-research\n",
        "git sparse-checkout set --no-cone seeds\n",
        "git checkout\n",
        "cd seeds\n",
        "pip install --no-deps ."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oSXgLijXjtUS"
      },
      "source": [
        "# Basic SEEDS using example data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rmpAoFEmyZy5"
      },
      "outputs": [],
      "source": [
        "# @title Imports\n",
        "import os\n",
        "\n",
        "import cartopy.crs as ccrs\n",
        "import cartopy.util\n",
        "import gcsfs\n",
        "import matplotlib as mpl\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import seaborn as sns\n",
        "import tensorflow as tf\n",
        "import tqdm.notebook as tqdm\n",
        "import xarray as xr\n",
        "\n",
        "from seeds import grid_lib\n",
        "\n",
        "sns.set_theme(context='paper', style='white', font_scale=1.2)\n",
        "sns.set_palette('colorblind')\n",
        "fs = gcsfs.GCSFileSystem('anon')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8SiZqdlnWg0Q"
      },
      "outputs": [],
      "source": [
        "# @title Initialize TPU\n",
        "resolver = tf.distribute.cluster_resolver.TPUClusterResolver()\n",
        "tf.config.experimental_connect_to_cluster(resolver)\n",
        "tf.tpu.experimental.initialize_tpu_system(resolver)\n",
        "strategy = tf.distribute.TPUStrategy(resolver)\n",
        "print(f'Found: {strategy.num_replicas_in_sync} TPUs.')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dbZmrPNdpvgG"
      },
      "outputs": [],
      "source": [
        "# @title List all available checkpoints\n",
        "base_dir = 'gs://gresearch/seeds'\n",
        "\n",
        "checkpoint_dir = f'{base_dir}/model_checkpoints'\n",
        "for path in tf.io.gfile.glob(checkpoint_dir + '/*'):\n",
        "  if not path.endswith('_$folder$'):\n",
        "    print(os.path.basename(path))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3X81ZShBuUS8"
      },
      "source": [
        "Checkpoint naming convention:\n",
        "- `gee_c2_s7`: SEEDS-GEE trained conditioning on 2 seeds for 7-day lead time.\n",
        "- `gpp_c2_s7_g3_r4`: SEEDS-GPP trained conditioning on 2 seeds for 7-day lead time, where the label mixture is 3 GEFS members and 4 ERA reanalyses."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "ztm_qJGzqcQK"
      },
      "outputs": [],
      "source": [
        "# @title Choose a model checkpoint\n",
        "checkpoint = \"gee_c2_s7\" # @param {type:\"string\"}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GSjf69FQtOFs"
      },
      "outputs": [],
      "source": [
        "# @title Load the checkpoint\n",
        "name_parts = checkpoint.split('_')\n",
        "num_seeds = int(name_parts[1][1:])\n",
        "lead_days = int(name_parts[2][1:])\n",
        "print('Number of seeds:', num_seeds)\n",
        "print('Lead time (days):', lead_days)\n",
        "\n",
        "with strategy.scope():\n",
        "  model = tf.saved_model.load(f'{checkpoint_dir}/{checkpoint}')\n",
        "print('Model total number of parameters:', sum([tf.size(var) for var in model._variables]).numpy())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_xunp9QW_2LP"
      },
      "outputs": [],
      "source": [
        "# @title Inspect the GEFS data for 2022 already regridded to the cubed sphere at 2 degrees.\n",
        "gefs = xr.open_zarr(fs.get_mapper(f'{base_dir}/data/gefs_forecast_2022_cubedsphere.zarr'))\n",
        "gefs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QRTdbv3QRRr5"
      },
      "outputs": [],
      "source": [
        "# @title Load one GEFS snapshot and create the wrapper for plotting\n",
        "base_time = pd.Timestamp('2022-01-01')\n",
        "\n",
        "snapshot = gefs.sel(time=base_time, number=0, step=lead_days)['anomaly'].load()\n",
        "grid = grid_lib.CubedSphere.on(snapshot.data)\n",
        "plot_gridder = grid.plot_gridder()\n",
        "\n",
        "def wrap(data):\n",
        "  rec = plot_gridder(data)\n",
        "  new_rec, new_lon = cartopy.util.add_cyclic_point(rec, rec.longitude)\n",
        "  cyclic = xr.DataArray(new_rec, coords={'latitude': rec.latitude, 'longitude': new_lon})\n",
        "  return cyclic"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "I3qZMAQNBF9n"
      },
      "outputs": [],
      "source": [
        "# @title Plot the forecasted anomaly from one GEFS member\n",
        "fig, axes = plt.subplots(4, 2, figsize=(10, 12), sharex=True, sharey=True, subplot_kw=dict(projection=ccrs.Robinson()))\n",
        "lon, lat = gefs.longitude.data, gefs.latitude.data\n",
        "for i, ax in enumerate(axes.flat):\n",
        "  wrap(snapshot.data[i]).plot(cmap='Spectral', transform=ccrs.PlateCarree(), ax=ax, add_colorbar=False)\n",
        "  ax.set_title(snapshot.field.data[i])\n",
        "  ax.coastlines()\n",
        "fig.subplots_adjust(wspace=0, hspace=0.1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9rPbXoN6HXGv"
      },
      "outputs": [],
      "source": [
        "# @title Load the base time climatology\n",
        "climatology = xr.open_zarr(fs.get_mapper(f'{base_dir}/data/climatology_cubedsphere.zarr')).load()\n",
        "monthday = base_time.month * 100 + base_time.day\n",
        "clim_mean = climatology.sel(monthday=monthday)['mean'].load()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5cNMwCsXJMwd"
      },
      "source": [
        "Note: SEEDS uses a fixed day-of-year climatology. The model uses the **base time** climatology as the input and samples **valid time** anomalies. So to convert the outputs back to the raw values, the valid time climatology also need to be loaded (this will be done in the later part of this colab)."
      ]
    },
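    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a worked form of the note above (a sketch in notation introduced here, not names from the SEEDS code): write $m_{\\textrm{valid}}$ and $s_{\\textrm{valid}}$ for the climatological mean and standard deviation at the valid day of year, and assume anomalies are standardized with them. A raw value $x$ and its anomaly $a$ are then related by\n",
        "\n",
        "$$\n",
        "a = \\frac{x - m_{\\textrm{valid}}}{s_{\\textrm{valid}}}, \\qquad x = a\\, s_{\\textrm{valid}} + m_{\\textrm{valid}},\n",
        "$$\n",
        "\n",
        "applied independently to each field and grid location."
      ]
    },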
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qIoQC9C-KRLI"
      },
      "outputs": [],
      "source": [
        "# @title Create conditining information\n",
        "# Take the first num_seeds GEFS anomalies as seeds.\n",
        "cond_anomaly = gefs.sel(time=base_time, number=np.arange(num_seeds), step=lead_days)['anomaly'].load().data\n",
        "# Concatenate those with the climatology to get the conditioning input\n",
        "cond_clim_mean = clim_mean.data\n",
        "cond = np.concatenate([cond_anomaly, cond_clim_mean[None, ...]], axis=0)\n",
        "print('Conditioning shape (#inputs, #fields, #locations) =',cond.shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "myIWyBYoLEiL"
      },
      "outputs": [],
      "source": [
        "# @title Utility functions for distributing data across accelerators\n",
        "def distribute(strategy, arr):\n",
        "  if arr.shape[0] % strategy.num_replicas_in_sync != 0:\n",
        "    raise ValueError('The batch size should be a multiple of num_replicas_in_sync.')\n",
        "  local_size = arr.shape[0] // strategy.num_replicas_in_sync\n",
        "  def value_fn(ctx):\n",
        "    k = ctx.replica_id_in_sync_group\n",
        "    return tf.cast(arr[k * local_size:(k + 1) * local_size], tf.float32)\n",
        "  return strategy.experimental_distribute_values_from_function(value_fn)\n",
        "\n",
        "def split(strategy, arr):\n",
        "  if arr.shape[0] % strategy.num_replicas_in_sync != 0:\n",
        "    raise ValueError('The batch size should be a multiple of num_replicas_in_sync.')\n",
        "  def value_fn(ctx):\n",
        "    return arr[ctx.replica_id_in_sync_group]\n",
        "  return strategy.experimental_distribute_values_from_function(value_fn)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "07yxwFRfLEiL"
      },
      "outputs": [],
      "source": [
        "# @title Generate more ensemble members\n",
        "batchsize = 2 * strategy.num_replicas_in_sync\n",
        "# The sampling function is completely deterministic for a fixed model_rng. So\n",
        "# each replica should have its own unique model_rng.\n",
        "model_rng = tf.constant(np.arange(strategy.num_replicas_in_sync) + 42, tf.int64)\n",
        "# Reducing num_diffusion_steps leads to faster generation but might degrade quality.\n",
        "num_diffusion_steps = tf.constant(600, tf.int64)\n",
        "min_diffusion_time = tf.constant(1e-3, tf.float32)\n",
        "\n",
        "# To generate batchsize samples at a time, we duplicate cond as a batch.\n",
        "tiled_cond = tf.cast(tf.tile(cond[None, ...], (batchsize, 1, 1, 1)), tf.float32)\n",
        "\n",
        "# Run the sampler.\n",
        "dist_model_rng = split(strategy, model_rng)\n",
        "dist_conditioning = distribute(strategy, tiled_cond)\n",
        "samples = strategy.run(model.sample, args=(dist_conditioning, dist_model_rng, num_diffusion_steps, min_diffusion_time))\n",
        "samples = strategy.gather(samples, axis=0).numpy()\n",
        "\n",
        "print('Samples shape: (#samples, 1, #fields, #locations) =', samples.shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TGUKZXWxLEiL"
      },
      "outputs": [],
      "source": [
        "# @title Plot the generated results\n",
        "field_id = 1\n",
        "plot_opts = dict(cmap='Spectral', transform=ccrs.PlateCarree(), add_colorbar=False)\n",
        "\n",
        "seeds = cond[:num_seeds, field_id]\n",
        "vmin, vmax = seeds.min() * 0.9, seeds.max() * 0.9\n",
        "\n",
        "fig, axes = plt.subplots(4, 3, figsize=(12, 10), sharex=True, sharey=True, subplot_kw=dict(projection=ccrs.Robinson()))\n",
        "for i, ax in enumerate(axes[0].flat):\n",
        "  if i \u003c num_seeds:\n",
        "    wrap(seeds[i]).plot(vmin=vmin, vmax=vmax, ax=ax, **plot_opts)\n",
        "    ax.coastlines()\n",
        "    ax.set_title(f'Cond #{i+1}')\n",
        "for i, ax in enumerate(axes[1:].flat):\n",
        "  if i \u003c batchsize:\n",
        "    wrap(samples[i, 0, field_id]).plot(vmin=vmin, vmax=vmax, ax=ax, **plot_opts)\n",
        "    ax.coastlines()\n",
        "    ax.set_title(f'Generated #{i+1}')\n",
        "fig.subplots_adjust(wspace=0, hspace=0.1)\n",
        "fig.suptitle(f'field={gefs.field.data[field_id]}');"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xT-ZS4A1_73K"
      },
      "outputs": [],
      "source": [
        "# @title Use the valid time climatology to map the anomalies to raw values\n",
        "valid_time = base_time + pd.Timedelta(days=lead_days)\n",
        "valid_monthday = valid_time.month * 100 + valid_time.day\n",
        "valid_clim = climatology.sel(monthday=valid_monthday).load()\n",
        "clim_mean = valid_clim['mean'].data\n",
        "clim_std = valid_clim['std'].data\n",
        "\n",
        "raw_samples = samples * clim_std + clim_mean\n",
        "raw_cond = cond * clim_std + clim_mean"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZWXgDAjL_7nU"
      },
      "outputs": [],
      "source": [
        "# @title Plot the generated results in raw values\n",
        "field_id = 4\n",
        "plot_opts = dict(cmap='Reds', transform=ccrs.PlateCarree(), add_colorbar=False)\n",
        "\n",
        "seeds = raw_cond[:num_seeds, field_id]\n",
        "vmin, vmax = seeds.min() * 0.9, seeds.max() * 0.9\n",
        "\n",
        "fig, axes = plt.subplots(4, 3, figsize=(12, 10), sharex=True, sharey=True, subplot_kw=dict(projection=ccrs.Robinson()))\n",
        "for i, ax in enumerate(axes[0].flat):\n",
        "  if i \u003c num_seeds:\n",
        "    wrap(seeds[i]).plot(vmin=vmin, vmax=vmax, ax=ax, **plot_opts)\n",
        "    ax.coastlines()\n",
        "    ax.set_title(f'Cond #{i+1}')\n",
        "for i, ax in enumerate(axes[1:].flat):\n",
        "  if i \u003c batchsize:\n",
        "    wrap(raw_samples[i, 0, field_id]).plot(vmin=vmin, vmax=vmax, ax=ax, **plot_opts)\n",
        "    ax.coastlines()\n",
        "    ax.set_title(f'Generated #{i+1}')\n",
        "fig.subplots_adjust(wspace=0, hspace=0.1)\n",
        "fig.suptitle(f'field={gefs.field.data[field_id]}');"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dcVkiGVEShKu"
      },
      "source": [
        "# Using SEEDS with live operational GEFS data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O4T8B2H_nyhd"
      },
      "source": [
        "Read the operational GEFS data on Google Cloud published by NOAA. For more information, see the [website](https://console.cloud.google.com/marketplace/product/noaa-public/gfs-ensemble-forecast-system).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CvhgXfgIhYnE"
      },
      "outputs": [],
      "source": [
        "# @title Operational GEFS data reader code\n",
        "def download_file(path):\n",
        "  local_name = os.path.basename(path)\n",
        "  if not os.path.exists(local_name):\n",
        "    tf.io.gfile.copy(path, local_name)\n",
        "\n",
        "def make_operational_gefs_aws_url(date, lead_days, number, file='pgrb2a'):\n",
        "  prefix = f'gep{number:02}' if number \u003e 0 else 'gec00'\n",
        "  # SEEDS models are trained only on the 00 hour forecast.\n",
        "  return f'gs://gfs-ensemble-forecast-system/gefs.{date}/00/atmos/{file}p5/{prefix}.t00z.{file}.0p50.f{24 * lead_days}'\n",
        "\n",
        "def make_hPa_getter(in_name, out_name, levels):\n",
        "  def get(x):\n",
        "    x = x[in_name].sel(isobaricInhPa=levels)\n",
        "    x['isobaricInhPa'] = [f'{out_name}_{level}hPa' for level in levels]\n",
        "    return x.to_dataset('isobaricInhPa')\n",
        "  return get\n",
        "\n",
        "def load_grib(path, getters):\n",
        "  coords = {'latitude', 'longitude', 'step', 'time'}\n",
        "  res = []\n",
        "  for selector, getter in getters:\n",
        "    res.append(getter(xr.open_dataset(path, engine='cfgrib', filter_by_keys=selector)))\n",
        "  res = xr.merge([part.drop(set(part.coords.keys()) - coords) for part in res])\n",
        "  return res.load()\n",
        "\n",
        "def load_gefs_grib(path, file='pgrb2a'):\n",
        "  if file == 'pgrb2a':\n",
        "    getters = [\n",
        "      ({'paramId': 167}, lambda x: x.rename({'t2m': 't_2m'})),\n",
        "      ({'paramId': 3054}, lambda x: x.rename({'pwat': 'tcwv'})),\n",
        "      ({'paramId': 130}, make_hPa_getter('t', 't', [850])),\n",
        "      ({'paramId': 131}, make_hPa_getter('u', 'u', [850])),\n",
        "      ({'paramId': 132}, make_hPa_getter('v', 'v', [850])),\n",
        "      ({'paramId': 156}, make_hPa_getter('gh', 'z', [500])),\n",
        "    ]\n",
        "  else:\n",
        "    getters = [\n",
        "      ({'paramId': 151}, lambda x: x.rename({'msl': 'msl'})),\n",
        "      ({'paramId': 133, 'typeOfLevel': 'isobaricInhPa'}, make_hPa_getter('q', 'q', [500])),\n",
        "    ]\n",
        "  return load_grib(path, getters)\n",
        "\n",
        "g = 9.80665  # Gravitational acceleration.\n",
        "\n",
        "def gefs_to_era5_units(ds):\n",
        "  for field in ds.data_vars:\n",
        "    if field.startswith('z_'):\n",
        "      ds[field] = ds[field] * 9.80665 # Unit: gpm -\u003e dm\n",
        "  return ds"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "CvOFlNv8arCM"
      },
      "outputs": [],
      "source": [
        "# @title Choose a base time and and lead time\n",
        "base_date = \"20231026\" # @param {type:\"string\"}\n",
        "lead_days = 7 # @param {type:\"integer\"}\n",
        "\n",
        "print('valid_date is', (pd.Timestamp(base_date) + pd.Timedelta(days=lead_days)).strftime('%Y%m%d'))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "m3vq0LHLb5Sp"
      },
      "source": [
        "Here to save time we only download the first 8 members. For the best results, download all 31 GEFS members instead."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b1CXCvA_jW44"
      },
      "outputs": [],
      "source": [
        "# @title Load GEFS data\n",
        "forecasts = []\n",
        "for number in range(8):  # To download all GEFS members, change this to 31.\n",
        "  forecast = []\n",
        "  for file in ['pgrb2a', 'pgrb2b']:\n",
        "    url = make_operational_gefs_aws_url(base_date, lead_days, number, file)\n",
        "    print(f'Fetch {url}...', flush=True)\n",
        "    download_file(url)\n",
        "    filename = os.path.basename(url)\n",
        "    forecast.append(load_gefs_grib(filename, file=file))\n",
        "  forecast = xr.merge(forecast).assign_coords({'number': number})\n",
        "  forecasts.append(forecast)\n",
        "forecasts = xr.concat(forecasts, 'number')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "boa5-4couA9l"
      },
      "outputs": [],
      "source": [
        "# @title Process the raw GEFS to a single tensor\n",
        "# Change to ERA5 units.\n",
        "forecasts = gefs_to_era5_units(forecasts)\n",
        "# Make sure the fields follow the order in the model.\n",
        "fields = [field.decode() for field in model.field_tags.numpy()]\n",
        "forecasts = forecasts[fields]\n",
        "forecasts = forecasts.to_array('field').transpose('number', 'field', 'latitude', 'longitude')\n",
        "print('(#members, #fields, #lats, #lons) =', forecasts.shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Bq7PaJMWvfvK"
      },
      "outputs": [],
      "source": [
        "# @title Regrid to cubed sphere at 2 degrees (48 nodes for 90 degrees)\n",
        "source_grid = grid_lib.Equirectangular.on(forecasts)\n",
        "target_grid = grid_lib.CubedSphere(48)\n",
        "gridder = source_grid.to(target_grid)\n",
        "coords = {k: v.data for k, v in forecasts.coords.items()}\n",
        "lon, lat = target_grid.grid_points\n",
        "coords['latitude'] = ('values', lat)\n",
        "coords['longitude'] = ('values', lon)\n",
        "gridded = xr.DataArray(gridder(forecasts.data), dims=['number', 'field', 'values'], coords=coords)\n",
        "print('(#members, #fields, #grid_points) =', gridded.shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nV6-J2eFwbQn"
      },
      "outputs": [],
      "source": [
        "# @title Get the climatologies and convert to anomalies\n",
        "base_time = pd.Timestamp(base_date)\n",
        "valid_time = base_time + pd.Timedelta(days=lead_days)\n",
        "base_clim_mean = climatology.sel(monthday=base_time.month * 100 + base_time.day)['mean'].load().data\n",
        "valid_clim_mean = climatology.sel(monthday=valid_time.month * 100 + valid_time.day)['mean'].load().data\n",
        "valid_clim_std = climatology.sel(monthday=valid_time.month * 100 + valid_time.day)['std'].load().data\n",
        "anomalies = (gridded - valid_clim_mean[None, ...]) / valid_clim_std[None, ...]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0psvZvcpdC3M"
      },
      "source": [
        "Here to save time we sample 4 rounds. In each round we generate 16 samples from 2 random seeds out of the 8 downloaded GEFS members. We get 4*16=64 samples in total.\n",
        "\n",
        "To get the best result, download the full 31 member ensemble before and do many rounds. For example, we can do 32 rounds and generate 16 samples from random seeds out of the 31 to get 32*16=512 samples. Scaling up to more TPUs can make this significantly faster."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "amYoGIeTU0_l"
      },
      "outputs": [],
      "source": [
        "# @title Generate more ensemble members\n",
        "rounds = 4\n",
        "samples_per_round = 2 * strategy.num_replicas_in_sync  # This 16 in this demo.\n",
        "rng = np.random.default_rng(42)\n",
        "num_diffusion_steps = tf.constant(600, tf.int64)\n",
        "min_diffusion_time = tf.constant(1e-3, tf.float32)\n",
        "\n",
        "src_ensemble_size = anomalies.shape[0]\n",
        "results = []\n",
        "for _ in tqdm.tqdm(range(rounds)):\n",
        "  seeds_idx = rng.choice(8, size=num_seeds, replace=False)\n",
        "  query = np.concatenate([anomalies[seeds_idx], base_clim_mean[None, ...]], axis=0)\n",
        "  tiled_cond = tf.tile(query[None, ...], (samples_per_round, 1, 1, 1))\n",
        "\n",
        "  model_rng = tf.constant(rng.integers(0, 2 ** 10, size=strategy.num_replicas_in_sync), tf.int64)\n",
        "  dist_model_rng = split(strategy, model_rng)\n",
        "  dist_conditioning = distribute(strategy, tiled_cond)\n",
        "  samples = strategy.run(model.sample, args=(dist_conditioning, dist_model_rng, num_diffusion_steps, min_diffusion_time))\n",
        "  samples = strategy.gather(samples, axis=0).numpy()\n",
        "  results.append(samples)\n",
        "results = np.concatenate(results, axis=0)\n",
        "print('Samples shape: (#samples, 1, #fields, #locations) =', results.shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "96mbZiNckw0V"
      },
      "outputs": [],
      "source": [
        "# @title Map back to raw values\n",
        "results_raw = results[:, 0] * valid_clim_std[None, ...] + valid_clim_mean[None, ...]\n",
        "gefs_raw = gridded.data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cHycy4Peq5GL"
      },
      "outputs": [],
      "source": [
        "# @title Plot the generated results in raw values\n",
        "field_id = 6\n",
        "plot_opts = dict(cmap='Blues', transform=ccrs.PlateCarree(), add_colorbar=False)\n",
        "\n",
        "seeds = gefs_raw[:num_seeds, field_id]\n",
        "vmin, vmax = seeds.min() * 0.9, seeds.max() * 0.9\n",
        "\n",
        "fig, axes = plt.subplots(4, 3, figsize=(12, 10), sharex=True, sharey=True, subplot_kw=dict(projection=ccrs.Robinson()))\n",
        "for i, ax in enumerate(axes[0].flat):\n",
        "  if i \u003c num_seeds:\n",
        "    wrap(seeds[i]).plot(vmin=vmin, vmax=vmax, ax=ax, **plot_opts)\n",
        "    ax.coastlines()\n",
        "    ax.set_title(f'Cond #{i+1}')\n",
        "for i, ax in enumerate(axes[1:].flat):\n",
        "  if i \u003c batchsize:\n",
        "    wrap(results_raw[i, field_id]).plot(vmin=vmin, vmax=vmax, ax=ax, **plot_opts)\n",
        "    ax.coastlines()\n",
        "    ax.set_title(f'Generated #{i+1}')\n",
        "fig.subplots_adjust(wspace=0, hspace=0.1)\n",
        "fig.suptitle(f'field={gefs.field.data[field_id]}');"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NYx-xhunnmVS"
      },
      "outputs": [],
      "source": [
        "# @title Postage stamp charts over Europe (first 2 rows are from GEFS, the next 4 rows are generated)\n",
        "fig, axes = plt.subplots(6, 4, figsize=(12, 18), subplot_kw=dict(projection=ccrs.LambertConformal(5, 48)))\n",
        "levels = 14\n",
        "\n",
        "def level_styler(low, mid, high):\n",
        "  lowbar = 8\n",
        "  return [low] * lowbar + [mid] + [high] * (14 - lowbar)\n",
        "\n",
        "vmin, vmax = None, None\n",
        "fixed_levels = levels\n",
        "zplot_opts = dict(cmap='Spectral_r', add_colorbar=False, transform=ccrs.PlateCarree())\n",
        "pplot_opts = dict(\n",
        "    transform=ccrs.PlateCarree(),\n",
        "    colors='darkslategray',\n",
        "    linewidths=level_styler(1, 1.5, 1),\n",
        "    linestyles=level_styler('dashed', 'solid', 'solid'),\n",
        ")\n",
        "for i, ax in enumerate(axes.flat):\n",
        "  if i // 4 \u003c 2:\n",
        "    ensemble = gefs_raw\n",
        "    start = 0\n",
        "  else:\n",
        "    ensemble = results_raw\n",
        "    start = 8\n",
        "  ax.set_extent((-40, 50, 10, 86), crs=ccrs.PlateCarree())\n",
        "  zplot = wrap(ensemble[i - start, 2] / g).plot(ax=ax, vmin=vmin, vmax=vmax, **zplot_opts)\n",
        "  if vmin is None:\n",
        "    vmin, vmax = zplot.get_clim()\n",
        "  pplot = wrap(ensemble[i - start, 0]).plot.contour(ax=ax, levels=fixed_levels, **pplot_opts)\n",
        "  if isinstance(fixed_levels, int):\n",
        "    fixed_levels = pplot.levels\n",
        "  ax.coastlines(resolution='110m', linewidth=1.5)\n",
        "fig.subplots_adjust(wspace=0, hspace=0)\n",
        "cbar_ax = fig.add_axes([0.96, 0.3, 0.02, 0.3])\n",
        "mpl.colorbar.ColorbarBase(cbar_ax, orientation='vertical', cmap='Spectral_r', norm=mpl.colors.Normalize(vmin, vmax), extend='both')\n",
        "cbar_ax.set_title('Geopotential at 500hPa height (m)', rotation='vertical', x=-0.7, y=0.15)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Kfrd8Nx1xDY4"
      },
      "source": [
        "# Advanced usages"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MZ8qQoH-xuTw"
      },
      "source": [
        "The forward SDE is\n",
        "\n",
        "$$\n",
        "dX_t = g(t)\\,dW_t.\n",
        "$$\n",
        "where $W_t$ is the standard Wiener process. The diffusion coefficient is given by $g(t)=b^t$ where the constant $b$ is the noise schedule base exponent. This is a Gaussian process with\n",
        "\n",
        "$$\n",
        "(X_t|X_0=x_0) \\sim N(x_0, \\sigma^2(t)I), \\qquad \\sigma^2(t)=\\int_0^t g^2(s)\\,ds.\n",
        "$$\n",
        "\n",
        "Let $p(t,x)$ be the probability density for $X_t$. The reverse SDE is\n",
        "\n",
        "$$\n",
        "dY_t = -g^2(t)\\nabla \\log p(t,Y_t)\\,dt+g(t)\\,d\\bar{W}_t,\n",
        "$$\n",
        "\n",
        "where $\\bar{W}_t$ is the reverse Wiener process. In diffusion modeling, we use a neural net $s_\\theta(t,x)$ to approximate $\\nabla \\log p(t,Y_t)$."
      ]
    },
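    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, the marginal variance has a closed form under the definitions above. This short derivation assumes only $g(t)=b^t$ with $b\u003e0$ and $b\\neq 1$; it has not been checked against the implementation of `model.marginal_std`:\n",
        "\n",
        "$$\n",
        "\\sigma^2(t)=\\int_0^t b^{2s}\\,ds=\\frac{b^{2t}-1}{2\\ln b}.\n",
        "$$\n",
        "\n",
        "For $b\\gg 1$ this gives $\\sigma(t)\\approx b^t/\\sqrt{2\\ln b}$ once $t$ is not close to 0, so the marginal standard deviation grows at the same exponential rate as $g(t)$."
      ]
    },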
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LQriUn1QtOCm"
      },
      "outputs": [],
      "source": [
        "# @title Plot the noise schedule\n",
        "print('SDE noise schedule base exponent:', model.sde_base_exponent.numpy())\n",
        "diffusion_time = tf.linspace(1e-3, 1.0, 32)\n",
        "diffusion_coef = model.diffusion_coef(diffusion_time)\n",
        "marginal_std = model.marginal_std(diffusion_time)\n",
        "\n",
        "fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharex=True)\n",
        "ax = axes[0]\n",
        "ax.plot(diffusion_time, diffusion_coef, label='Diffusion coefficient')\n",
        "ax.set_title('Diffusion coefficient $g(t)$')\n",
        "ax = axes[1]\n",
        "ax.plot(diffusion_time, marginal_std)\n",
        "ax.set_title('Marginal std $\\sigma(t)$');"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HpSyHvOd0nWw"
      },
      "source": [
        "Because $(X_t|X_0=x_0) \\sim N(x_0, \\sigma^2(t)I)$, we see that $x_0+\\sigma(t)Z$ for $Z\\sim N(0,I)$ has the same distribution as $X_t$. The neural score function is trained on the denoising loss\n",
        "\n",
        "$$\n",
        "\\mathbb{E}_{t\\sim U(0,1]}\\mathbb{E}_{x\\sim p_{\\textrm{data}}(x)}\\mathbb{E}_{Z\\sim N(0,I)}||s_\\theta(t,x+\\sigma(t)Z)\\sigma(t) + Z||_2^2.\n",
        "$$\n",
        "\n",
        "Hence we for any $x\\sim p_{\\textrm{data}}(x)$, $Z\\sim N(0,I)$, $t\\in (0,1]$, we expect\n",
        "\n",
        "$$\n",
        "D(t, x+\\sigma(t)Z):= x\n",
        "$$\n",
        "\n",
        "is a denoiser. We can thus visualize the learned score function by looking at the corresponding denoiser."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ijbd3QW3ATif"
      },
      "outputs": [],
      "source": [
        "# @title Evaluate the denoiser at diffusion times for some random perturbations\n",
        "batchsize = strategy.num_replicas_in_sync\n",
        "\n",
        "# This continues from before using the 2022-01-01 example data, where we took\n",
        "# the first num_seeds GEFS member for conditioning. Thus, here we take the\n",
        "# last GEFS member as the label for denoising.\n",
        "label = gefs.sel(time='2022-01-01', step=lead_days).isel(number=slice(-1, None, None))['anomaly'].load().data\n",
        "label = tf.tile(tf.cast(label, tf.float32)[None, ...], (batchsize, 1, 1, 1))\n",
        "tiled_cond = tf.tile(tf.cast(cond, tf.float32)[None, ...], (batchsize, 1, 1, 1))\n",
        "\n",
        "# Create batchsize copies of perturbed samples at different diffusion times.\n",
        "diffusion_time = tf.linspace(1e-2, 1.0, batchsize)\n",
        "marginal_std = model.marginal_std(diffusion_time)\n",
        "noise = tf.random.normal((batchsize, 1) + label.shape[2:])\n",
        "perturbed = label + noise * marginal_std[:, None, None, None]\n",
        "\n",
        "# Evaluate the model score function and compute the denoiser result.\n",
        "dist_conditioning = distribute(strategy, tiled_cond)\n",
        "dist_diffusion_time = distribute(strategy, diffusion_time)\n",
        "dist_perturbed = distribute(strategy, perturbed)\n",
        "scores = strategy.run(model.score, args=(dist_conditioning, dist_diffusion_time, dist_perturbed))\n",
        "scores = strategy.gather(scores, axis=0).numpy()\n",
        "\n",
        "denoised = perturbed + scores * (marginal_std ** 2)[:, None, None, None]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sA1X6SQWqXEx"
      },
      "outputs": [],
      "source": [
        "# @title Plot the denoiser results\n",
        "field_id = 1\n",
        "\n",
        "fig, axes = plt.subplots(8, 3, figsize=(12, 16), sharex=True, sharey=True, subplot_kw=dict(projection=ccrs.Robinson()))\n",
        "for i in range(batchsize):\n",
        "  truth = label[i, 0, field_id].numpy()\n",
        "  vmin, vmax = truth.min(), truth.max()\n",
        "  for j, data in enumerate([label[i, 0, field_id], perturbed[i, 0, field_id], denoised[i, 0, field_id]]):\n",
        "    wrap(data.numpy()).plot(vmin=vmin, vmax=vmax, cmap='Spectral', transform=ccrs.PlateCarree(), ax=axes[i][j], add_colorbar=False)\n",
        "  axes[i][0].set_ylabel(f't={diffusion_time[i]:.2f}')\n",
        "axes[0][0].set_title('Label')\n",
        "axes[0][1].set_title('Perturbed')\n",
        "axes[0][2].set_title('Denoised')\n",
        "fig.suptitle(f'Denoising plots for the field {gefs.field.data[field_id]}')\n",
        "fig.subplots_adjust(wspace=0, hspace=0.02)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jaz84vp8dMJC"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "lulVWjnBFPC8"
      ],
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1xD-Q2EfLqSM0JXcoog5cZ8enXwrJU_F5",
          "timestamp": 1702467958882
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
